import numpy as np
# import pykalman
import matplotlib.pyplot as plt
import scipy.stats as stats
import ekf_testv6 as ekf
import summary_of_sims as sum_sims
import copy
import matplotlib.transforms as mtransforms

# Simulated temperature series; each row twTR[i, :] is one simulation
# (rows are iterated one at a time by the filter loop further down).
twTR=sum_sims.twTR

# Simulated ocean-heat-content series, indexed the same way (ocn_sims[i, :]).
# Passing the path lets NumPy open *and close* the file itself; the previous
# explicit open(...) handle was never closed.
ocn_sims = np.genfromtxt("OHCcomb.csv", dtype=float, delimiter=',')

# Number of simulations comes from twTR (which must be 2-D); the number of
# time steps is taken from the filter module rather than from twTR's width.
sms, _ = np.shape(twTR)
smyrs=ekf.n_iters
#dif=0.21
twTRmean=sum_sims.twTRmean
twTRstd=sum_sims.twTRstd
smdate=1850

# Time axis and observations used by the observation-driven filter run.
dates=ekf.dates
temps=ekf.temps


# Weighted 2:1 blend of the filter-line colour and the state-band colour,
# channel by channel, plus its 0-255 integer form.
wt1, wt2 = 2, 1
channel_blend = [
    (ekf.colorekf[k] * wt1 + ekf.colorstate[k] * wt2) / (wt1 + wt2)
    for k in range(3)
]
colorekfsims = channel_blend
colorekfsimsl = [int(level * 255) for level in channel_blend]

# The blend above is superseded by these fixed choices, which are what the
# rest of the script actually uses for the simulation curves.
colorekfsims = "orange"
colorekfsimsl = [215, 100, 61]

# Both figures share one layout: panel "a" spans the top two rows of the wide
# left column, "b" and "c" sit beside it in the narrow right column, and
# "d" / "e" fill the bottom row.
mosaic_layout = """
ab
ac
de
"""
mosaic_gridspec = {
    "width_ratios": [7, 1.25], "height_ratios": [3,3,3], "wspace": 0.35, "hspace": 0.45, "left":0.1, "right":0.96
}

# Figure 1 (saved later as ekf_on_sims.*).
fig = plt.figure(figsize=(9,9))
ax_dict = fig.subplot_mosaic(mosaic_layout, gridspec_kw=dict(mosaic_gridspec))

# Figure 2 (saved later as ekf_on_sims_OHCA.png).
fig2 = plt.figure(figsize=(9,9))
ax_dict2 = fig2.subplot_mosaic(mosaic_layout, gridspec_kw=dict(mosaic_gridspec))

# Reference curve and +/- 2 sigma band of the observation-driven filter's
# second state, scaled by ekf.zJ_from_W for plotting.  The band skips the
# first four time steps.  Labels are raw strings so the TeX backslashes are
# not treated as (invalid) Python escape sequences.
ax_dict2["a"].plot(ekf.dates,ekf.xh1d*ekf.zJ_from_W,'-',label=r'a posteriori EBM-KF state $\hat{H}_n$ using Zanna (2019)', color=ekf.colorekf,zorder=4)
ax_dict2["a"].fill_between(dates[4:], (ekf.xh1d[4:]-2*ekf.stdPd[4:])*ekf.zJ_from_W, (ekf.xh1d[4:]+2*ekf.stdPd[4:])*ekf.zJ_from_W,label=r"95% CI from $P_n$ of EBM-KF OHCA state $\hat{H}_n$ using Zanna (2019)", color=ekf.colorstate,alpha=0.25,zorder=2)

# Panel (a) of figure 1: shared boilerplate from the filter module, fixed
# y ticks / limits, then the observation-driven filter state with its
# +/- 2 sigma band (the band skips the first four time steps).
plt.sca(ax_dict["a"])
ekf.plot_boilerplate(ax_dict["a"])
plt.yticks(np.arange(286.2,288.4,0.2))
ax_dict["a"].set_ylim(286.3,288.1)
# Raw strings keep the TeX backslashes from being read as (invalid) Python
# escape sequences; the rendered labels are unchanged.
ax_dict["a"].plot(ekf.dates,ekf.xh1s,'-',label=r'a posteriori EBM-KF GMST state $\hat{T}_n$ using real HadCRUT5', color=ekf.colorekf,zorder=4)
ax_dict["a"].fill_between(dates[4:], ekf.xh1s[4:]-2*ekf.stdP[4:], ekf.xh1s[4:]+2*ekf.stdP[4:],label=r"95% CI from $P_n$ of EBM-KF state $\hat{T}_n$ using real HadCRUT5", color=ekf.colorstate,alpha=0.25,zorder=2)
# Private copies of the observation-driven filter output (states and their
# standard deviations) used as the reference for the normalised differences.
xh1ds=copy.deepcopy(ekf.xh1s)
stdPs=copy.deepcopy(ekf.stdP)
xh1d=copy.deepcopy(ekf.xh1d)
stdPd=copy.deepcopy(ekf.stdPd)
avglen=30
avglend2=0
max_sim_len=174  # cap on the number of time steps fed to the filter

# Per-simulation results, one row per simulation.  Rows start as NaN so that
# a simulation shorter than smyrs leaves its tail as NaN, which every
# NaN-aware statistic ignores.
all_xhats=np.full([sms, smyrs], np.nan)  # first filter state (GMST)
all_xhatd=np.full([sms, smyrs], np.nan)  # second filter state (OHCA)
all_Ps=np.full([sms, smyrs], np.nan)     # std dev of the first state
all_Pd=np.full([sms, smyrs], np.nan)     # std dev of the second state
all_qqys=np.full([sms, smyrs], np.nan)   # (sim state - reference state) / reference std dev
all_qqyd=np.full([sms, smyrs], np.nan)
for i in range(sms):
    cur_temps=twTR[i,:]
    # A simulation ends at its first NaN.  The previous
    # list(...).index(float("nan")) lookup could never match because NaN
    # compares unequal to itself, so NaN-padded tails were fed to the filter.
    nan_positions=np.flatnonzero(np.isnan(cur_temps))
    len_sim=int(nan_positions[0]) if nan_positions.size else len(cur_temps)
    len_sim=min(len_sim, max_sim_len)
    # Two observation columns: temperature and ocean heat content.
    observ = np.transpose(np.array([cur_temps[:len_sim],ocn_sims[i,:len_sim]/ekf.zJ_from_W])) ##NOTE - should I enlarge the LENS2 OHC by 5/3 to match Zanna obs?
    sim_xavg, this_Ps=ekf.ekf_run(observ,len_sim, retPs=True)

    dateslice=dates[(sum_sims.smdate-sum_sims.sdate):(sum_sims.smdate-sum_sims.sdate+len_sim)]

    # Store only the len_sim steps this simulation actually produced.
    all_xhats[i,:len_sim]=sim_xavg[:,0]
    all_xhatd[i,:len_sim]=sim_xavg[:,1]
    # abs() guards the square root against negative diagonal entries.
    all_Ps[i,:len_sim]=np.sqrt(np.abs(this_Ps[:,0,0]))
    all_Pd[i,:len_sim]=np.sqrt(np.abs(this_Ps[:,1,1]))
    all_qqys[i,:len_sim]=(sim_xavg[:,0]-xh1ds[0+avglend2:len_sim-avglend2])/stdPs[0+avglend2:len_sim-avglend2]
    all_qqyd[i,:len_sim]=(sim_xavg[:,1]-xh1d[0+avglend2:len_sim-avglend2])/stdPd[0+avglend2:len_sim-avglend2]
    if i==(sms-1):
        # Only the last simulation carries a legend label, so the legend gets
        # a single entry for the whole bundle of thin curves.
        ax_dict["a"].plot(dateslice,sim_xavg[:,0],'-',label=r'a posteriori EBM-KF GMST states $(\hat{T}_n)_j$ using LENS2 sims', color=colorekfsims, linewidth=0.5,zorder=1)
        ax_dict2["a"].plot(dateslice,sim_xavg[:,1]*ekf.zJ_from_W,'-',label=r'a posteriori EBM-KF OHCA states $(\hat{H}_n)_j$ using LENS2 sims ', color=colorekfsims, linewidth=0.5,zorder=1)
        #plt.plot(dateslice,ocn_sims[i,:len_sim]/ekf.zJ_from_W,'-',label='OHCA from simulations $(\\Psi_n)_j$', color=[52/256,235/256,235/256], linewidth=0.5,zorder=2)
    else:
        ax_dict["a"].plot(dateslice,sim_xavg[:,0],'-', color=sum_sims.colorgen(i,(70,30,30),colorekfsimsl), linewidth=0.5,zorder=1)
        ax_dict2["a"].plot(dateslice,sim_xavg[:,1]*ekf.zJ_from_W,'-', color=sum_sims.colorgen(i,(70,30,30),colorekfsimsl), linewidth=0.5,zorder=1)
        #ax_dict2["a"].plot(dateslice,ocn_sims[i,:len_sim],'-', color=sum_sims.colorgen(i,(70,30,30),(52,235,235)), linewidth=0.5,zorder=2)
# Turn the "never filled" placeholder values (0, 350, -4000) into NaN so they
# cannot leak into the statistics below.
all2_xhats=np.where(all_xhats == 0, float("nan"), all_xhats)
all2_qqys=np.where(all_qqys == 350, float("nan"), all_qqys)
all2_xhatd=np.where(all_xhatd == -4000, float("nan"), all_xhatd)
all2_qqyd=np.where(all_qqyd == 350, float("nan"), all_qqyd)
# Ensemble means over simulations, per time step.  These now use the masked
# arrays; previously the masks were computed but the unmasked arrays were
# averaged, so any unfilled placeholder would have biased the mean.
mean_xhats=np.nanmean(all2_xhats,axis=0)
mean_xhatd=np.nanmean(all2_xhatd,axis=0)
# NOTE(review): despite the name, this is the spread of the raw simulated
# temperatures (twTR), not of the filtered states - confirm this is intended.
std_xhats=np.nanstd(twTR,axis=0) #/np.sqrt(sms)

print(np.mean(std_xhats))
print(np.mean(stdPs))

# Ensemble-mean curves on panel (a) of each figure.  Raw strings keep the TeX
# backslashes from being read as (invalid) Python escape sequences.
ax_dict["a"].plot(ekf.dates[avglend2:smyrs-avglend2],mean_xhats,'-',label=r'mean simulated EBM-KF GMST state $\overline{(\hat{T}_n)_j}$', color='black',zorder=3)
ax_dict2["a"].plot(ekf.dates[avglend2:smyrs-avglend2],mean_xhatd*ekf.zJ_from_W,'-',label=r'mean simulated EBM-KF OHCA state $\overline{(\hat{H}_n)_j}$', color='black',zorder=3)

# Agreement between the ensemble-mean state and the observation-driven state.
r2 = ekf.r2_score(mean_xhats, ekf.xh1s)
print('r2 score for GMST ens mean vs EBM-KF is', r2)
r2 = ekf.r2_score(mean_xhatd*ekf.zJ_from_W, ekf.xh1d*ekf.zJ_from_W)
print('r2 score for OHCA ens mean vs EBM-KF is', r2)

#plt.fill_between(dates, xh1ds-2*stdPs, xh1ds+2*stdPs,label="95% CI of EBM-KF state from measurements ($2\sqrt{P_{n}}$)", color='seagreen',alpha=0.2,zorder=2)
plt.title("Simulated EBM-KF GMST States vs HadCRUT5 EBM-KF State Uncertainty")

# Both time-series panels get the same legend, with the entries rearranged
# from plotting order into display order.
order = [0,3,1,2]
for legend_ax in (ax_dict["a"], ax_dict2["a"]):
    handles, labels = legend_ax.get_legend_handles_labels()
    legend_ax.legend(
        [handles[k] for k in order],
        [labels[k] for k in order],
        loc='upper left',
    )

# Figure 1, panel (a): right-hand axis showing the same range shifted by
# 273.15 (kelvin to degrees Celsius).
gmst_ax = ax_dict["a"]
mn, mx = gmst_ax.get_ylim()
ax1b = gmst_ax.twinx()
ax1b.set_yticks(np.arange(13.2,15.2,0.2))
ax1b.set_ylim(mn-273.15, mx-273.15)
ax1b.set_ylabel('Temperature (°C)')
ax1b.yaxis.set_label_coords(0.95, 0.5)

# Figure 2, panel (a): axis text, then a right-hand axis whose limits are the
# left-hand limits scaled by ekf.zJtomm/10.
ohca_ax = ax_dict2["a"]
mn, mx = ohca_ax.get_ylim()
ohca_ax.set(
    ylabel='Heat (zJ)',
    xlim=[1850,2025],
    xlabel="Year",
    title="Simulated EBM-KF OHCA States vs Zanna EBM-KF State Uncertainty",
)
ax1b = ohca_ax.twinx()
ax1b.set_yticks(np.arange(-1,8,1))
ax1b.set_ylim(mn*ekf.zJtomm/10, mx*ekf.zJtomm/10)
ax1b.set_ylabel('Thermosteric Sea Level Rise (cm)')
ax1b.yaxis.set_label_coords(0.95, 0.3)

lpad=0
ax1=ax_dict["b"]
ax2=ax_dict["c"]

# Normalised real-vs-simulation differences with the NaN entries removed;
# the same sample feeds the Q-Q plot and the histogram.
real_sim_z = all2_qqys[~np.isnan(all2_qqys)]

# Panel (c): normal Q-Q plot, drawn by scipy onto the current axes and then
# restyled (line 0 = sample markers, line 1 = fitted line).
plt.sca(ax2)
stats.probplot(real_sim_z, dist="norm", plot=plt)
qq_lines = ax2.get_lines()
qq_markers = qq_lines[0]
qq_fit = qq_lines[1]
qq_markers.set_markerfacecolor(ekf.colorstate)
qq_markers.set_markeredgewidth(0)
qq_markers.set_markersize(3.0)
qq_fit.set_color(ekf.pcolor)
qq_fit.set_linewidth(1.5)
ax2.set_xlabel("Theoretical Quantiles",labelpad=lpad)
ax2.set_title("")

#fig.suptitle("Variance of Simulated EBM-KF States $\\neq$ State Variance ($P$) of Measurements")

# Panel (b): density histogram of the same sample with a standard normal
# density drawn over it.
plt.sca(ax1)
ax1.hist(real_sim_z, bins=sum_sims.nbins, density=True, color=ekf.colorstate)
xnorm = np.linspace(-3.5, 3.5, 100)
ynorm  = stats.norm.pdf(xnorm)
ax1.plot(xnorm, ynorm, color=ekf.pcolor)
ax1.set_xlim([-6,6])
ax1.set_title("Real-Sims Dist.")
ax1.set_ylabel("Probability Density")
ax1.set_xlabel("State Std Dev from $P_n$",labelpad=lpad)


# Differences of every simulation's GMST state relative to every other
# simulation, computed in one broadcast instead of a double Python loop.
# pair_diffs_s[i, j, :] = all_xhats[i, :] - all_xhats[j, :]
pair_diffs_s = all_xhats[:, np.newaxis, :] - all_xhats[np.newaxis, :, :]
# Self-comparisons (i == j) are excluded by writing NaN directly; this
# replaces the old "fill with 50, then mask values equal to 50" scheme, which
# would also have dropped any genuine value that happened to equal 50.
same_sim = np.arange(sms)
pair_diffs_s[same_sim, same_sim, :] = np.nan
# Same differences expressed in units of simulation i's own state std dev.
pair_scaled_s = pair_diffs_s / all_Ps[:, np.newaxis, :]

# Row i holds, block after block of smyrs columns, sim i versus sim 0, 1, ...
allall2_qqys = pair_scaled_s.reshape(sms, sms*smyrs)
all2absdfs = pair_diffs_s.reshape(sms, sms*smyrs)

# Per-simulation summary of its differences from all other simulations.
mean_qqys=np.nanmean(allall2_qqys,axis=1)
meandfs=np.nanmean(all2absdfs,axis=1)
stdev_qqys=np.nanstd(allall2_qqys,axis=1)
# Panel (e): density histogram of simulation 0's scaled differences from all
# other simulations, with a standard normal density drawn over it.
ax3=ax_dict["e"]
plt.sca(ax3)
ax3.hist(allall2_qqys[0,~np.isnan(allall2_qqys[0,:])],bins=sum_sims.nbins, density=True,color=colorekfsims)
xnorm = np.linspace(-3.5, 3.5, 100)
ynorm  = stats.norm.pdf(xnorm)
ax3.plot(xnorm,ynorm, color=ekf.pcolor)
ax3.set_xlim([-6,6])
ax3.set_xlabel("State Std Dev from $(P_n)_0$",labelpad=lpad)
ax3.set_ylabel("Probability Density")
ax3.set_title("A Sim-Sims Dist.")

# Panel (d): one point per simulation (mean vs std dev of its scaled
# differences), with simulation 0 outlined, the centroid over all pairs, and
# the real-vs-simulations comparison.
ax4=ax_dict["d"]
ax4.plot(mean_qqys[1:],stdev_qqys[1:],'.',color=colorekfsims,label="Sim-Sims comparisons")
ax4.plot(mean_qqys[0],stdev_qqys[0],'.',color=colorekfsims,markeredgewidth=1, markeredgecolor='k',label="A Sim-Sims comparison (e)")
ax4.plot(np.nanmean(allall2_qqys),np.nanstd(allall2_qqys),'x',color='black',label="Centroid of Sim-Sims differences")
ax4.plot(np.nanmean(all2_qqys),np.nanstd(all2_qqys),'*',markeredgewidth=1, markeredgecolor='k',color=ekf.colorstate,label="Real-Sims comparison (b)")
ax4.set_xlabel("LENS2 States Ensemble Mean Bias in Std Devs from $(P_n)_{i}$ or $P_n$")
ax4.set_ylabel("LENS2 States Ensemble Std Dev \n in Std Devs from $(P_n)_{i}$ or $P_n$")
ax4.yaxis.set_label_coords(-.07, .5)
# Guide lines at mean = 0 and std dev = 1 (the standard-normal values).
ax4.plot([0,0],[0,3],'-',color=ekf.pcolor,zorder=0)
ax4.plot([-2,3],[1,1],'-',color=ekf.pcolor,zorder=0)
ax4.set_ylim([0.4,1.8])
ax4.set_xlim([-1.1,1.1])
ax4.legend(loc="best",fontsize="8")
# Raw string keeps the TeX backslash from being read as an (invalid) Python
# escape sequence; the rendered title is unchanged.
ax4.set_title(r"Distribution of Simulated EBM-KF States relative to Single $\hat{T}_n$ and $P_n$")
print(np.mean(abs(mean_qqys)),np.max(mean_qqys),np.min(mean_qqys))
print(np.mean(abs(meandfs)),np.max(meandfs),np.min(meandfs),'K')

# Tag each panel of figure 1 with its letter, placed just outside the
# top-left corner of the axes.
axlabels=['a)','b)','c)','d)','e)']
for label in axlabels[:len(ax_dict)]:
    ax = ax_dict[label[0]]
    # Shift given in inches (points / 72): 25 pt to the left, 5 pt up.
    trans = mtransforms.ScaledTranslation(-25/72, 5/72, fig.dpi_scale_trans) #-20/72, 7/72
    ax.text(
        0.0, 1.0, label,
        transform=ax.transAxes + trans,
        fontsize='large',
        va='bottom',
    ) #, fontfamily='serif')

# Write figure 1 in both vector and raster form.
fig.savefig("ekf_on_sims.pdf", format="pdf")
fig.savefig("ekf_on_sims.png", dpi=400,format="png")
print(np.nanstd(allall2_qqys))
print(np.nanmean(all_Ps))

ax1=ax_dict2["b"]
ax2=ax_dict2["c"]

# Normalised real-vs-simulation OHCA differences with the NaN entries
# removed; the same sample feeds the Q-Q plot and the histogram.
real_sim_zd = all2_qqyd[~np.isnan(all2_qqyd)]

# Panel (c): normal Q-Q plot, drawn by scipy onto the current axes and then
# restyled (line 0 = sample markers, line 1 = fitted line).
plt.sca(ax2)
stats.probplot(real_sim_zd, dist="norm", plot=plt)
qq_lines = ax2.get_lines()
qq_markers = qq_lines[0]
qq_fit = qq_lines[1]
qq_markers.set_markerfacecolor(ekf.colorstate)
qq_markers.set_markeredgewidth(0)
qq_markers.set_markersize(3.0)
qq_fit.set_color(ekf.pcolor)
qq_fit.set_linewidth(1.5)
ax2.set_xlabel("Theoretical Quantiles",labelpad=lpad)
ax2.set_title("")

#fig.suptitle("Variance of Simulated EBM-KF States $\\neq$ State Variance ($P$) of Measurements")

# Panel (b): density histogram of the same sample with a standard normal
# density drawn over it.
plt.sca(ax1)
ax1.hist(real_sim_zd, bins=sum_sims.nbins, density=True, color=ekf.colorstate)
xnorm = np.linspace(-3.5, 3.5, 100)
ynorm  = stats.norm.pdf(xnorm)
ax1.plot(xnorm, ynorm, color=ekf.pcolor)
ax1.set_xlim([-15,10])
ax1.set_title("Real-Sims Dist.")
ax1.set_ylabel("Probability Density")
ax1.set_xlabel("State Std Dev from $P_n$",labelpad=lpad)

# Differences of every simulation's OHCA state relative to every other
# simulation, computed in one broadcast instead of a double Python loop.
# pair_diffs_d[i, j, :] = all_xhatd[i, :] - all_xhatd[j, :]
pair_diffs_d = all_xhatd[:, np.newaxis, :] - all_xhatd[np.newaxis, :, :]
# Self-comparisons (i == j) are excluded by writing NaN directly; this
# replaces the old "fill with 350 / 50, then mask equal values" scheme, which
# would also have dropped any genuine value equal to the placeholder.
same_sim = np.arange(sms)
pair_diffs_d[same_sim, same_sim, :] = np.nan
# Same differences expressed in units of simulation i's own state std dev.
pair_scaled_d = pair_diffs_d / all_Pd[:, np.newaxis, :]

# Row i holds, block after block of smyrs columns, sim i versus sim 0, 1, ...
allall2_qqyd = pair_scaled_d.reshape(sms, sms*smyrs)
all2absdfd = pair_diffs_d.reshape(sms, sms*smyrs)

# Per-simulation summary of its differences from all other simulations.
mean_qqyd=np.nanmean(allall2_qqyd,axis=1)
stdev_qqyd=np.nanstd(allall2_qqyd,axis=1)
meandfd=np.nanmean(all2absdfd,axis=1)
# Panel (e): density histogram of simulation 0's scaled OHCA differences from
# all other simulations, with a standard normal density drawn over it.
ax3=ax_dict2["e"]
plt.sca(ax3)
ax3.hist(allall2_qqyd[0,~np.isnan(allall2_qqyd[0,:])],bins=sum_sims.nbins, density=True,color=colorekfsims)
xnorm = np.linspace(-3.5, 3.5, 100)
ynorm  = stats.norm.pdf(xnorm)
ax3.plot(xnorm,ynorm, color=ekf.pcolor)
ax3.set_xlim([-6,6])
# The label previously ended in "{0}}$"; the stray closing brace made the
# math expression unbalanced.
ax3.set_xlabel("State Std Dev from $(P_n)_{0}$",labelpad=lpad)
ax3.set_ylabel("Probability Density")
ax3.set_title("A Sim-Sims Dist.")

# Panel (d): one point per simulation (mean vs std dev of its scaled
# differences), with simulation 0 outlined, the centroid over all pairs, and
# the real-vs-simulations comparison.
ax4=ax_dict2["d"]
ax4.plot(mean_qqyd[1:],stdev_qqyd[1:],'.',color=colorekfsims,label="Sim-Sims comparisons")
ax4.plot(mean_qqyd[0],stdev_qqyd[0],'.',color=colorekfsims,markeredgewidth=1, markeredgecolor='k',label="A Sim-Sims comparison (e)")
ax4.plot(np.nanmean(allall2_qqyd),np.nanstd(allall2_qqyd),'x',color='black',label="Centroid of Sim-Sims differences")
ax4.plot(np.nanmean(all2_qqyd),np.nanstd(all2_qqyd),'*',markeredgewidth=1, markeredgecolor='k',color=ekf.colorstate,label="Real-Sims comparison (b)")
ax4.set_xlabel("LENS2 States Ensemble Mean Bias in Std Devs from $(P_n)_{i}$ or $P_n$")
ax4.set_ylabel("LENS2 States Ensemble Std Dev \n in Std Devs from $(P_n)_{i}$ or $P_n$")
ax4.yaxis.set_label_coords(-.07, .5)
# Guide lines at mean = 0 and std dev = 1 (the standard-normal values).
ax4.plot([0,0],[0,6],'-',color=ekf.pcolor,zorder=0)
ax4.plot([-3,3],[1,1],'-',color=ekf.pcolor,zorder=0)
ax4.set_ylim([0.4,5])
ax4.set_xlim([-2.6,2.6])
ax4.legend(loc="best",fontsize="8")
# Raw string keeps the TeX backslash from being read as an (invalid) Python
# escape sequence; the rendered title is unchanged.
ax4.set_title(r"Distribution of Simulated EBM-KF States relative to Single $\hat{T}_n$ and $P_n$")
print(np.mean(abs(mean_qqyd)),np.max(mean_qqyd),np.min(mean_qqyd))
print(np.mean(abs(meandfd))*ekf.zJ_from_W,np.max(meandfd)*ekf.zJ_from_W,np.min(meandfd)*ekf.zJ_from_W,'ZJ')

# Tag each panel of figure 2 with its letter, placed just outside the
# top-left corner of the axes.
axlabels=['a)','b)','c)','d)','e)']
for label in axlabels[:len(ax_dict2)]:
    ax=ax_dict2[label[0]]
    # The shift is tied to fig2's OWN dpi transform, so it is the same
    # physical distance at every output resolution.  It was previously built
    # from the first figure's transform, which does not follow fig2's dpi
    # during savefig - hence the label drifting between PDF and 400-dpi PNG
    # output.  With that fixed, the 25 pt left / 5 pt up offset used for the
    # first figure applies here too (was a hand-tuned -45/72).
    trans = mtransforms.ScaledTranslation(-25/72, 5/72, fig2.dpi_scale_trans)
    ax.text(0.0, 1.0, label, transform=ax.transAxes + trans,
        fontsize='large', va='bottom') #, fontfamily='serif')

#fig2.savefig("ekf_on_sims_OHCA.pdf", format="pdf")
fig2.savefig("ekf_on_sims_OHCA.png", dpi=400,format="png")


plt.show()

